from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from nets import resnet_test
from nets import densenet_test
from nets import mobilenet_v2
from nets import shufflenet_v2

# Registry of selectable networks: each string key maps to the network object
# exported by the corresponding module in the `nets` package. get_net_fn()
# resolves names against this mapping.
nets_map = dict(
  resnet_test=resnet_test.ResNet,
  densenet_test=densenet_test.DenseNet,
  mobilenet_v2=mobilenet_v2.MobileNet,
  shufflenet_v2=shufflenet_v2.ShuffleNet,
)

def get_net_fn(name):
  """Return the network entry registered under ``name`` in ``nets_map``.

  Args:
    name: Registry key of the network, e.g. 'resnet_test' or 'mobilenet_v2'.

  Returns:
    The value stored in ``nets_map`` for ``name``, returned as-is (it is not
    called or instantiated here).

  Raises:
    ValueError: If ``name`` is not a key of ``nets_map``. The message lists
      the registered names so the caller can correct the typo.
  """
  if name not in nets_map:
    # Pass the format arguments as an explicit tuple: ``'%s' % name`` raises
    # TypeError (or drops text) when ``name`` is itself a tuple, which would
    # mask the intended ValueError.
    raise ValueError('Name of network unknown %s (available: %s)'
                     % (name, ', '.join(sorted(nets_map))))
  return nets_map[name]

  